# Copyright 2021-2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""train DCGAN and get checkpoint files."""
import argparse
import ast
import os
import time
import datetime
import numpy as np
from mindspore import context
from mindspore import nn, Tensor
from mindspore.train.callback import CheckpointConfig, _InternalCallbackParam, ModelCheckpoint, RunContext
from mindspore.context import ParallelMode
from mindspore.common import set_seed
from mindspore.communication.management import init, get_rank, get_group_size
from src.dataset import create_dataset_imagenet
from src.config import dcgan_imagenet_cfg as cfg
from src.generator import Generator
from src.discriminator import Discriminator
from src.cell import WithLossCellD, WithLossCellG
from src.dcgan import DCGAN
import matplotlib.pyplot as plt
import matplotlib
matplotlib.use('Agg')
set_seed(1)

def save_imgs(gen_imgs, idx, save_dir="/cache/images"):
    """
    Save images in 4 * 4 format when training on the modelarts

    Each image is drawn on its own subplot of a freshly created figure, which is
    always closed afterwards so figures do not accumulate across epochs and the
    grid is never drawn on top of a figure left open by other plotting code.

    Inputs:
        - **gen_imgs** (array) - Images generated by the generator, indexed along
          axis 0; each image is channels-first (C, H, W) and is transposed to
          (H, W, C) for display. Values are mapped with ``x * 127.5 + 127.5``
          (presumably from [-1, 1]) and clipped to [0, 255]. At most 16 images
          fit in the 4 * 4 grid. The caller's array is not modified.
        - **idx** (int) - Training epoch, used as the output file name.
        - **save_dir** (str) - Directory the ``<idx>.png`` file is written to.
          Default: "/cache/images".
    """
    fig = plt.figure()
    try:
        for index, img in enumerate(gen_imgs):
            ax = fig.add_subplot(4, 4, index + 1)
            # Rescale into a new array instead of writing back into gen_imgs.
            pixels = np.clip(img * 127.5 + 127.5, 0, 255)
            # imshow expects channels-last (H, W, C) integer pixels in [0, 255].
            ax.imshow(np.transpose(pixels, (1, 2, 0)).astype(np.uint8))
            ax.axis("off")
        fig.savefig(os.path.join(save_dir, "{}.png".format(idx)))
    finally:
        plt.close(fig)


def save_losses(G_losses_list, D_losses_list, idx, save_dir="/cache/losses"):
    """
    Save Loss visualization images when training on the modelarts

    The plot is drawn on a dedicated figure that is closed after saving, so a new
    figure is not leaked every epoch and later plotting calls do not reuse it.

    Inputs:
        - **G_losses_list** (list) - Generator loss list.
        - **D_losses_list** (list) - Discriminator loss list.
        - **idx** (int) - Training epoch, used as the output file name.
        - **save_dir** (str) - Directory the ``<idx>.png`` file is written to.
          Default: "/cache/losses".
    """
    fig = plt.figure(figsize=(10, 5))
    try:
        ax = fig.add_subplot(1, 1, 1)
        ax.set_title("Generator and Discriminator Loss During Training")
        ax.plot(G_losses_list, label="G")
        ax.plot(D_losses_list, label="D")
        ax.set_xlabel("iterations")
        ax.set_ylabel("Loss")
        ax.legend()
        fig.savefig(os.path.join(save_dir, "{}.png".format(idx)))
    finally:
        plt.close(fig)


parser = argparse.ArgumentParser(description='MindSpore dcgan training')
parser.add_argument("--run_modelart", type=ast.literal_eval, default=False,
                    help="Run on modelArt, default is false.")
parser.add_argument("--run_distribute", type=ast.literal_eval, default=False,
                    help="Run distribute, default is false.")
parser.add_argument('--device_target', type=str, default='Ascend', help='GPU or Ascend')
parser.add_argument('--device_id', type=int, default=0, help='device id of Ascend (Default: 0)')
parser.add_argument('--data_url', default=None, help='Directory contains ImageNet-1k dataset.')
parser.add_argument('--train_url', default=None, help='Directory of training output.')
parser.add_argument('--images_url', default=None, help='Location of images outputs.')
parser.add_argument('--losses_url', default=None, help='Location of losses outputs.')
args = parser.parse_args()

# The branches below configure the MindSpore context and define the module-level
# names the training entry point uses: device_num, rank, local_input_url and
# local_output_url (plus local_images_url, local_losses_url and `mox` on ModelArts).
run_modelart = args.run_modelart
if run_modelart:
    # The job layout comes from environment variables; fall back to a single
    # device when they are absent instead of failing on int(None).
    device_id = int(os.getenv('DEVICE_ID', '0'))
    device_num = int(os.getenv('RANK_SIZE', '1'))
    # Each device stages its data and checkpoints in its own cache directory.
    local_input_url = '/cache/data' + str(device_id)
    local_output_url = '/cache/ckpt' + str(device_id)
    local_images_url = '/cache/images'
    local_losses_url = '/cache/losses'
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend",
                        save_graphs=False)
    context.set_context(device_id=device_id)

    if device_num > 1:
        init()
        # Pass the global rank reported by the communication layer rather than the
        # host-local device id, which only coincides with it on a single host.
        rank = get_rank()
        context.set_auto_parallel_context(device_num=device_num,
                                          global_rank=rank,
                                          parallel_mode=ParallelMode.DATA_PARALLEL,
                                          gradients_mean=True)
    else:
        rank = 0
    import moxing as mox

    # Copy the dataset and the image/loss output locations into the local cache.
    mox.file.copy_parallel(src_url=args.data_url, dst_url=local_input_url)
    mox.file.copy_parallel(src_url=args.images_url, dst_url=local_images_url)
    mox.file.copy_parallel(src_url=args.losses_url, dst_url=local_losses_url)

elif args.run_distribute:
    if args.device_target == 'Ascend':
        device_id = args.device_id
        context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False)
        context.set_context(device_id=device_id)
        init()
        # Use the actual size of the communication group joined by init() (as the
        # GPU branch does); a hard-coded 1 does not match a multi-device job.
        device_num = get_group_size()
        context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
                                          gradients_mean=True)
        local_input_url = args.data_url
        local_output_url = args.train_url
        rank = get_rank()
    elif args.device_target == 'GPU':
        context.set_context(mode=context.GRAPH_MODE, device_target='GPU', enable_graph_kernel=True)
        init()
        rank = get_rank()
        device_num = get_group_size()
        args.device_id = rank
        context.set_auto_parallel_context(device_num=device_num, global_rank=rank,
                                          parallel_mode=ParallelMode.DATA_PARALLEL,
                                          gradients_mean=True)
        local_input_url = args.data_url
        local_output_url = args.train_url
    else:
        # Without this, none of the names used by the training loop get defined
        # and the script would later fail with an unrelated NameError.
        raise ValueError("Distributed training only supports device_target 'Ascend' or 'GPU', "
                         "got '{}'.".format(args.device_target))
else:
    # Single-device training on the requested target.
    device_id = args.device_id
    device_target = args.device_target
    context.set_context(mode=context.GRAPH_MODE, device_target=device_target, save_graphs=False)
    context.set_context(device_id=device_id)
    rank = 0
    device_num = 1
    local_input_url = args.data_url
    local_output_url = args.train_url


if __name__ == '__main__':
    start = time.time()
    # Load Dataset
    ds = create_dataset_imagenet(os.path.join(local_input_url), num_parallel_workers=4)

    steps_per_epoch = ds.get_dataset_size()

    # Define Network: discriminator and generator share one BCE criterion; each is
    # wrapped in its own loss cell and single-step training cell, and DCGAN runs
    # one update of both per call.
    netD = Discriminator()
    netG = Generator()

    criterion = nn.BCELoss(reduction='mean')

    netD_with_criterion = WithLossCellD(netD, netG, criterion)
    netG_with_criterion = WithLossCellG(netD, netG, criterion)

    optimizerD = nn.Adam(netD.trainable_params(), learning_rate=cfg.learning_rate, beta1=cfg.beta1)
    optimizerG = nn.Adam(netG.trainable_params(), learning_rate=cfg.learning_rate, beta1=cfg.beta1)

    myTrainOneStepCellForD = nn.TrainOneStepCell(netD_with_criterion, optimizerD)
    myTrainOneStepCellForG = nn.TrainOneStepCell(netG_with_criterion, optimizerG)

    dcgan = DCGAN(myTrainOneStepCellForD, myTrainOneStepCellForG)
    dcgan.set_train()

    # checkpoint save: one checkpoint every steps_per_epoch steps, keeping up to
    # cfg.epoch_size of them (i.e. one per epoch, none discarded).
    ckpt_config = CheckpointConfig(save_checkpoint_steps=steps_per_epoch,
                                   keep_checkpoint_max=cfg.epoch_size)
    ckpt_cb = ModelCheckpoint(config=ckpt_config, directory=local_output_url, prefix='dcgan')

    # The training loop is hand-written, so the callback parameters that
    # ModelCheckpoint reads are built and advanced manually below.
    cb_params = _InternalCallbackParam()
    cb_params.train_network = dcgan
    cb_params.batch_num = steps_per_epoch
    cb_params.epoch_num = cfg.epoch_size
    # For each epoch
    cb_params.cur_epoch_num = 0
    cb_params.cur_step_num = 0
    run_context = RunContext(cb_params)
    ckpt_cb.begin(run_context)

    # Fixed noise so the images rendered after every epoch are comparable.
    np.random.seed(1)
    fixed_noise = Tensor(np.random.normal(size=(16, cfg.latent_size, 1, 1)).astype("float32"))

    data_loader = ds.create_dict_iterator(output_numpy=True, num_epochs=cfg.epoch_size)
    G_losses = []
    D_losses = []
    # Start Training Loop
    print("Starting Training Loop...")
    for epoch in range(cfg.epoch_size):
        # For each batch in the dataloader
        for i, data in enumerate(data_loader):
            real_data = Tensor(data['image'])
            latent_code = Tensor(data["latent_code"])
            netD_loss, netG_loss = dcgan(real_data, latent_code)
            if i % 50 == 0:
                print("Date time: ", datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'), "\tepoch: ", epoch, "/",
                      cfg.epoch_size, "\tstep: ", i, "/", steps_per_epoch, "\tDloss: ", netD_loss, "\tGloss: ",
                      netG_loss)
            D_losses.append(netD_loss.asnumpy())
            G_losses.append(netG_loss.asnumpy())
            cb_params.cur_step_num = cb_params.cur_step_num + 1
        cb_params.cur_epoch_num = cb_params.cur_epoch_num + 1

        # Only global rank 0 saves and uploads. In data-parallel mode the
        # replicas are kept in sync, and this stops several devices from writing
        # to the same output URLs (the previous device_id-based check let every
        # ModelArts device, or device 0 on every host, do so).
        if rank == 0:
            print("================saving model===================")
            ckpt_cb.step_end(run_context)
            if run_modelart:
                fake = netG(fixed_noise)
                print("================saving images===================")
                save_imgs(fake.asnumpy(), epoch + 1)
                print("================saving losses===================")
                save_losses(G_losses, D_losses, epoch + 1)
                mox.file.copy_parallel(src_url=local_images_url, dst_url=args.images_url)
                mox.file.copy_parallel(src_url=local_losses_url, dst_url=args.losses_url)
                mox.file.copy_parallel(src_url=local_output_url, dst_url=args.train_url)
            print("================success================")
    t = time.time() - start
    print("train time:", t)
